"""
将太美标注工具所标注的数据格式转换为PURE的训练格式
"""

import os
import json
import re

from tqdm import tqdm


def get_fields(entities, connections, start, end):
    """Collect entity spans and relations that fall inside [start, end].

    An entity is kept when its whole span lies within the window; its end
    index is converted from exclusive to inclusive. A relation is kept only
    when both of its endpoints survived the entity filter.

    Returns a tuple (ner_items, relation_items) where each ner item is
    [start, inclusive_end, label] and each relation item is
    [from_start, from_end, to_start, to_end, label].
    """
    span_ner = []
    id_to_span = {}  # entity id -> [start, inclusive_end], for relation lookup
    for ent in entities:
        if start <= ent['startIndex'] and ent['endIndex'] <= end:
            # id2entity / id2relation are module-level category-id -> text maps.
            item = [ent['startIndex'], ent['endIndex'] - 1, id2entity[ent['categoryId']]]
            span_ner.append(item)
            id_to_span[ent['id']] = item[:2]

    span_relations = [
        id_to_span[conn['fromId']] + id_to_span[conn['toId']] + [id2relation[conn['categoryId']]]
        for conn in connections
        if conn['fromId'] in id_to_span and conn['toId'] in id_to_span
    ]
    return span_ner, span_relations


def split_text_re(text, pattern=None):
    """Split *text* into sentence chunks at end-of-sentence punctuation.

    Each chunk keeps its terminator attached; a trailing fragment without a
    terminator is also kept.

    Args:
        text: the raw text to split.
        pattern: optional compiled regex matching sentence terminators;
            defaults to the module-level ``end_char_re``.

    Returns:
        A list of ``(chunk, start, end)`` tuples with end-exclusive
        character offsets into *text*.
    """
    if pattern is None:
        pattern = end_char_re  # module-level default: CJK/Western sentence enders
    data = []
    start = 0
    for m in pattern.finditer(text):
        data.append((text[start:m.end()], start, m.end()))
        start = m.end()
    # Original used `for ... else`, but with no `break` the else-branch always
    # runs — a plain trailing check says the same thing without the misdirection.
    if start < len(text):
        data.append((text[start:], start, len(text)))

    return data


def convert_data_format(mode, data_dir):
    """Convert one data split (all JSON files in *data_dir*) and write
    ``{mode}.json`` into the module-level ``save_dir``.

    Each source file's text is split into sentences; entity offsets are
    re-based onto each sentence. Sentences that are blank or shorter than
    two characters are dropped.

    Args:
        mode: split name used in item ids and the output filename
            ("train" / "dev" / "test").
        data_dir: directory containing the annotation-tool JSON exports.
    """
    files = sorted(os.listdir(data_dir))
    max_len = 0
    items = []
    for i, file in enumerate(tqdm(files, desc=mode), start=1):
        # Explicit encoding: the annotations contain Chinese text, so don't
        # rely on the platform default.
        with open(os.path.join(data_dir, file), encoding="utf-8") as fr:
            data = json.load(fr)

        raw_text = data['annotations'][0]['content']
        labels = data['annotations'][0]['labels']
        connections = data['annotations'][0]['connections']
        for sub_text, start, end in split_text_re(raw_text):
            # Skip blank or single-character fragments.
            if not sub_text.strip() or len(sub_text) < 2:
                continue

            # NOTE: curr_relations is computed but not emitted yet (see todo).
            curr_ner, curr_relations = get_fields(labels, connections, start, end)
            entities = []
            for ner_s, ner_e, ner_tag in curr_ner:
                start_idx = ner_s - start  # re-base onto the sub-sentence
                end_idx = ner_e - start    # inclusive end index
                entities.append({
                    "start_idx": start_idx,
                    "end_idx": end_idx,
                    "type": ner_tag,
                    "entity": sub_text[start_idx: end_idx + 1],
                })

            # todo: update (emit relations as well)
            items.append({
                # NOTE(review): the id is per source FILE, so every sentence of a
                # file shares the same id — confirm that is intended downstream.
                "id": f"{mode}-{i}",
                "text": sub_text,
                "entities": entities,
            })
            max_len = max(max_len, len(sub_text))

    with open(os.path.join(save_dir, f"{mode}.json"), 'w', encoding="utf-8") as fw:
        json.dump(items, fw, ensure_ascii=False, indent=2)
    # Bug fix: report the measured maximum, not a hard-coded "7,902".
    print(f"{mode}'s max length: {max_len:,}")


train_dir = "/home/zhoazj/Downloads/pred_tm/image_report/20220310/train"
valid_dir = "/home/zhoazj/Downloads/pred_tm/image_report/20220310/valid"
test_dir = "/home/zhoazj/Downloads/pred_tm/image_report/20220310/test"
label_p = "/home/zhoazj/Downloads/pred_tm/image_report/labelCategories_CT1.json"
connection_p = "/home/zhoazj/Downloads/pred_tm/image_report/connectionCategories_CT1.json"
save_dir = "/home/zhoazj/Desktop/codes/gitee/BERT-NER-Pytorch/datasets/emr-image-report"

with open(label_p, encoding="utf-8") as fr:
    id2entity = {item["id"]: item["text"].strip() for item in json.load(fr)["labelCategories"]}

with open(connection_p, encoding="utf-8") as fr:
    id2relation = {item["id"]: item["text"].strip() for item in json.load(fr)["connectionCategories"]}

end_char_re = re.compile(r"[。！？!?]")
convert_data_format("train", train_dir)
convert_data_format("dev", valid_dir)
convert_data_format("test", test_dir)
